package ml.combust.mleap.core.feature

import ml.combust.mleap.core.types.{StructType, TensorType}
import ml.combust.mleap.tensor.{DenseTensor, Tensor}
import org.apache.spark.ml.linalg.mleap.BLAS
import org.apache.spark.ml.linalg.{Vector, Vectors}

/**
  * Locality-sensitive hashing model based on bucketed random projections.
  *
  * Each hash value is computed by projecting the input onto one of the
  * `randomUnitVectors` and dividing the projection by `bucketLength`, then
  * flooring the result.
  *
  * Created by hollinwilkins on 12/28/16.
  *
  * @param randomUnitVectors projection vectors; one hash value is produced per vector
  * @param bucketLength divisor applied to each projection before flooring
  * @param inputSize dimension declared for the input tensor in [[inputSchema]]
  */
case class BucketedRandomProjectionLSHModel(randomUnitVectors: Seq[Vector],
                                            bucketLength: Double,
                                            inputSize: Int) extends LSHModel {
  /** Alias for [[predict]]. */
  def apply(features: Vector): Tensor[Double] = predict(features)

  /**
    * Hashes a feature vector.
    *
    * @param features vector to hash
    * @return a dense tensor of shape `(randomUnitVectors.length, 1)` holding
    *         `floor(dot(features, v) / bucketLength)` for each projection vector `v`
    */
  def predict(features: Vector): Tensor[Double] = {
    val hashValues: Seq[Double] = randomUnitVectors.map { randUnitVector =>
      Math.floor(BLAS.dot(features, randUnitVector) / bucketLength)
    }

    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    DenseTensor(hashValues.toArray, Seq(hashValues.length, 1))
  }

  /** Euclidean distance between two keys (square root of the squared distance). */
  override def keyDistance(x: Vector, y: Vector): Double = {
    Math.sqrt(Vectors.sqdist(x, y))
  }

  /**
    * Smallest squared distance over position-wise pairs of hash vectors.
    *
    * The sequences are zipped, so only the first `min(x.length, y.length)`
    * pairs are compared. `min` throws if there are no pairs to compare.
    */
  override def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
    // Since it's generated by hashing, it will be a pair of dense vectors.
    x.zip(y).map { case (left, right) => Vectors.sqdist(left, right) }.min
  }

  /** Input is a one-dimensional double tensor of size `inputSize`. */
  override def inputSchema: StructType = StructType("input" -> TensorType.Double(inputSize)).get

  /**
    * Output is a double tensor with one row per projection vector and a single
    * column, matching the shape that [[predict]] constructs.
    */
  override def outputSchema: StructType =
    StructType("output" -> TensorType.Double(randomUnitVectors.length, 1)).get
}
